import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm.auto as tqdm
import wandb

# Human-readable legend labels keyed by the shield identifier stored in each
# run's config. The "slugs_*" variants intentionally share the label of their
# plain counterparts.
name_map = dict(
    centralized="Centralized Shield",
    slugs_centralized="Centralized Shield",
    decentralized="Decentralized Shield",
    slugs_decentralized="Decentralized Shield",
    none="No Shield",
)


def make_history_and_config_table(run):
    """Flatten one W&B run into a table of rewards tagged with its run config.

    Args:
        run: A W&B run object; only ``run.history()`` and ``run.config`` are
            used. ``run.config`` must contain the keys ``"shield"``,
            ``"randomize_starts"`` and ``"grid_world_map_name"``.

    Returns:
        A new DataFrame with the columns ``train/episode_rewards_0``,
        ``shield``, ``randomize_starts``, ``map_name`` and ``timestep``. The
        config values are broadcast to every row. ``timestep`` is a copy of the
        history frame's index.
    """
    # ``.copy()`` detaches the column selection from the history frame so the
    # column assignments below are not chained assignments on a slice (which
    # pandas reports with SettingWithCopyWarning).
    run_dataframe = run.history()[["train/episode_rewards_0"]].copy()
    config = run.config
    run_dataframe["shield"] = config["shield"]
    run_dataframe["randomize_starts"] = config["randomize_starts"]
    run_dataframe["map_name"] = config["grid_world_map_name"]
    # NOTE(review): this is the history frame's row index, not W&B's ``_step``;
    # aligning runs on it assumes every run logs the same number of evenly
    # spaced points — confirm.
    run_dataframe["timestep"] = run_dataframe.index
    return run_dataframe


def process_pd_runs(pd_runs, *, z=1.96, total_steps=2_500_000, num_log_points=200):
    """Aggregate per-run reward rows into per-configuration curve statistics.

    Rows are grouped by (map_name, randomize_starts, shield, timestep).
    ``train/episode_rewards_0`` is reduced to its mean, sample std and count
    within each group.

    Args:
        pd_runs: Long-format table with the columns ``map_name``,
            ``randomize_starts``, ``shield``, ``timestep`` and
            ``train/episode_rewards_0`` (one row per run per timestep).
        z: Multiplier applied to the standard error of the mean. The default of
            1.96 gives the half-width of a ~95% normal-approximation interval.
        total_steps: Training steps represented by ``num_log_points`` logged
            points; used to convert ``timestep`` into ``steps``.
        num_log_points: Number of logged points spanning ``total_steps``.

    Returns:
        A DataFrame indexed by the four group keys. Its columns are:
        - the tuples ``("train/episode_rewards_0", "mean" | "std" | "count" | "stderr")``;
        - plain copies of the group keys (``map_name``, ``randomize_starts``,
          ``shield``, ``timestep``);
        - ``steps``.
        Groups with a single row have NaN ``std`` and ``stderr``.
    """
    reward = "train/episode_rewards_0"
    grouped = pd_runs.groupby(["map_name", "randomize_starts", "shield", "timestep"]).agg(
        {reward: ["mean", "std", "count"]})
    # Despite the column name, this is z * SEM (a confidence-interval
    # half-width), not the plain standard error; the name is kept for
    # compatibility with existing readers of this table.
    grouped[(reward, "stderr")] = z * grouped[(reward, "std")] / np.sqrt(grouped[(reward, "count")])
    # Copy the group keys out of the index into ordinary columns so callers can
    # filter with boolean masks.
    grouped = pd.concat((grouped, grouped.index.to_frame()), axis=1)
    # Map the logging-point index onto training steps (defaults: 2.5M steps over
    # 200 logged points).
    grouped["steps"] = grouped["timestep"] * total_steps / num_log_points

    return grouped


def plot_pd_runs(pd_runs, filename, *, map_names=("ISR", "MIT", "Pentagon", "SUNY"),
                 shields=("none", "slugs_centralized", "slugs_decentralized")):
    """Draw a (start mode x map) grid of reward curves and save it to ``filename``.

    Rows are fixed starts (``randomize_starts`` == "False"), then random starts
    ("True"). Columns follow ``map_names``. Every subplot shows, for each shield
    in ``shields``, the mean of ``train/episode_rewards_0`` against ``steps``.
    A shaded band of +/- the ``stderr`` column surrounds each curve.

    Args:
        pd_runs: Table with the ``map_name``, ``randomize_starts``, ``shield``
            and ``steps`` columns plus the ``("train/episode_rewards_0",
            "mean")`` and ``("train/episode_rewards_0", "stderr")`` columns.
        filename: Destination passed to ``Figure.savefig`` (saved at 300 dpi).
        map_names: Map names, one subplot column each.
        shields: Shield identifiers drawn on every subplot. Legend labels come
            from ``name_map``, falling back to the raw identifier.
    """
    reward_mean = ("train/episode_rewards_0", "mean")
    reward_band = ("train/episode_rewards_0", "stderr")
    start_modes = ("False", "True")
    nrows, ncols = len(start_modes), len(map_names)

    # The run config may carry randomize_starts as a bool or as the string
    # "True"/"False"; comparing string forms matches either representation
    # instead of silently selecting no rows.
    rand_as_str = pd_runs["randomize_starts"].astype(str)

    fig = plt.figure(figsize=(6.25 * ncols, 5 * nrows))
    try:
        for rownum, rand_start in enumerate(start_modes):
            for colnum, map_name in enumerate(map_names):
                ax = fig.add_subplot(nrows, ncols, rownum * ncols + colnum + 1)
                ax.set_title(f"{map_name}, {'Random' if rand_start == 'True' else 'Fixed'} Starts")
                if rownum == nrows - 1:
                    ax.set_xlabel("Training Steps")
                if colnum == 0:
                    ax.set_ylabel("Reward")

                cell = pd_runs[(pd_runs["map_name"] == map_name) & (rand_as_str == rand_start)]
                for shield in shields:
                    these_runs = cell[cell["shield"] == shield]
                    # Label lines on the first subplot only, so the figure
                    # legend lists each shield exactly once.
                    line_kwargs = {}
                    if rownum == 0 and colnum == 0:
                        line_kwargs["label"] = name_map.get(shield, shield)
                    mean = these_runs[reward_mean]
                    band = these_runs[reward_band]
                    ax.plot(these_runs["steps"], mean, **line_kwargs)
                    ax.fill_between(these_runs["steps"], mean - band, mean + band, alpha=0.5)

        fig.legend(loc="lower center", ncol=max(1, len(shields)))
        fig.savefig(filename, dpi=300)
    finally:
        # Release the pyplot-managed figure so repeated calls do not accumulate
        # open figures.
        plt.close(fig)


if __name__ == '__main__':
    # Fetch every sweep run trained without a separate evaluation shield. The
    # list is materialised so tqdm knows the total.
    sweep_runs = list(
        wandb.Api().runs(
            "dmelcer9/Centralized-Verification-Slugs-Sweep",
            filters={"config.evaluation_shield": None},
        )
    )
    per_run_tables = [make_history_and_config_table(r) for r in tqdm.tqdm(sweep_runs)]
    summary = process_pd_runs(pd.concat(per_run_tables))
    plot_pd_runs(summary, "slugs_sweep.png")
